{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# Copyright (c) 2024, NVIDIA CORPORATION.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "#"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\" width=\"90px\">\n",
    "\n",
    "# Distributed Hyperparameter Tuning: Optuna + JoblibSpark\n",
    "\n",
    "\n",
    "This demo demonstrates distributed hyperparameter tuning for XGBoost using the [JoblibSpark backend](https://github.com/joblib/joblib-spark), building on this [example from Databricks](https://docs.databricks.com/en/machine-learning/automl-hyperparam-tuning/optuna.html).  \n",
    "We implement best practices to precompute data and maximize computations on the GPU.  \n",
    "\n",
    "\n",
    "\n",
    "Reference: https://forecastegy.com/posts/xgboost-hyperparameter-tuning-with-optuna/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Note:\n",
    "Before running, please make sure you've followed the relevant [setup instructions](../README.md) for your environment (standalone or databricks).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "import os\n",
    "import requests\n",
    "import joblib\n",
    "from joblibspark import register_spark\n",
    "import optuna\n",
    "from optuna.samplers import TPESampler\n",
    "import xgboost as xgb\n",
    "from pyspark.sql import SparkSession\n",
    "from pyspark import TaskContext, SparkConf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download the dataset\n",
    "\n",
    "We'll use the [red wine quality dataset](https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv) to regress wine quality based on features such as acidity, sugar content, etc.  \n",
    "\n",
    "**Note**: This example uses a small dataset for demonstration purposes. The performance advantages of distributed training are best realized with large datasets and computational workloads."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "cwd = os.getcwd()\n",
    "os.mkdir(os.path.join(cwd, \"data\")) if not os.path.exists(os.path.join(cwd, \"data\")) else None\n",
    "filepath = os.path.join(cwd, \"data\", \"winequality-red.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "File downloaded and saved to /home/rishic/Code/myforks/spark-rapids-examples/examples/ML+DL-Examples/Optuna-Spark/optuna-examples/data/winequality-red.csv\n"
     ]
    }
   ],
   "source": [
    "url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv\"\n",
    "\n",
    "response = requests.get(url)\n",
    "if response.status_code == 200:\n",
    "    with open(filepath, \"wb\") as f:\n",
    "        f.write(response.content)\n",
    "    print(f\"File downloaded and saved to {filepath}\")\n",
    "else:\n",
    "    print(f\"Failed to download the file. Status code: {response.status_code}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 1. Running Optuna locally"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cudf\n",
    "from cuml.metrics.regression import mean_squared_error\n",
    "from cuml.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fixed acidity</th>\n",
       "      <th>volatile acidity</th>\n",
       "      <th>citric acid</th>\n",
       "      <th>residual sugar</th>\n",
       "      <th>chlorides</th>\n",
       "      <th>free sulfur dioxide</th>\n",
       "      <th>total sulfur dioxide</th>\n",
       "      <th>density</th>\n",
       "      <th>pH</th>\n",
       "      <th>sulphates</th>\n",
       "      <th>alcohol</th>\n",
       "      <th>quality</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>7.4</td>\n",
       "      <td>0.70</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1.9</td>\n",
       "      <td>0.076</td>\n",
       "      <td>11.0</td>\n",
       "      <td>34.0</td>\n",
       "      <td>0.9978</td>\n",
       "      <td>3.51</td>\n",
       "      <td>0.56</td>\n",
       "      <td>9.4</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>7.8</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.00</td>\n",
       "      <td>2.6</td>\n",
       "      <td>0.098</td>\n",
       "      <td>25.0</td>\n",
       "      <td>67.0</td>\n",
       "      <td>0.9968</td>\n",
       "      <td>3.20</td>\n",
       "      <td>0.68</td>\n",
       "      <td>9.8</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>7.8</td>\n",
       "      <td>0.76</td>\n",
       "      <td>0.04</td>\n",
       "      <td>2.3</td>\n",
       "      <td>0.092</td>\n",
       "      <td>15.0</td>\n",
       "      <td>54.0</td>\n",
       "      <td>0.9970</td>\n",
       "      <td>3.26</td>\n",
       "      <td>0.65</td>\n",
       "      <td>9.8</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>11.2</td>\n",
       "      <td>0.28</td>\n",
       "      <td>0.56</td>\n",
       "      <td>1.9</td>\n",
       "      <td>0.075</td>\n",
       "      <td>17.0</td>\n",
       "      <td>60.0</td>\n",
       "      <td>0.9980</td>\n",
       "      <td>3.16</td>\n",
       "      <td>0.58</td>\n",
       "      <td>9.8</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>7.4</td>\n",
       "      <td>0.70</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1.9</td>\n",
       "      <td>0.076</td>\n",
       "      <td>11.0</td>\n",
       "      <td>34.0</td>\n",
       "      <td>0.9978</td>\n",
       "      <td>3.51</td>\n",
       "      <td>0.56</td>\n",
       "      <td>9.4</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   fixed acidity  volatile acidity  citric acid  residual sugar  chlorides  \\\n",
       "0            7.4              0.70         0.00             1.9      0.076   \n",
       "1            7.8              0.88         0.00             2.6      0.098   \n",
       "2            7.8              0.76         0.04             2.3      0.092   \n",
       "3           11.2              0.28         0.56             1.9      0.075   \n",
       "4            7.4              0.70         0.00             1.9      0.076   \n",
       "\n",
       "   free sulfur dioxide  total sulfur dioxide  density    pH  sulphates  \\\n",
       "0                 11.0                  34.0   0.9978  3.51       0.56   \n",
       "1                 25.0                  67.0   0.9968  3.20       0.68   \n",
       "2                 15.0                  54.0   0.9970  3.26       0.65   \n",
       "3                 17.0                  60.0   0.9980  3.16       0.58   \n",
       "4                 11.0                  34.0   0.9978  3.51       0.56   \n",
       "\n",
       "   alcohol  quality  \n",
       "0      9.4        5  \n",
       "1      9.8        5  \n",
       "2      9.8        5  \n",
       "3      9.8        6  \n",
       "4      9.4        5  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = cudf.read_csv(filepath, delimiter=\";\")\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prepare the train/validation sets. Precompute the Quantile DMatrix, which is used by histogram-based tree methods to save memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = data.iloc[:, :-1].values\n",
    "y = data[\"quality\"].values\n",
    "X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)\n",
    "Xy_train_qdm = xgb.QuantileDMatrix(X_train, y_train)  # Precompute Quantile DMatrix to avoid repeated quantization every trial."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Objective function\n",
    "\n",
    "We define the objective and a hyperparameter search space to optimize via the `trial.suggest_` methods.  \n",
    "\n",
    "In each trial, new hyperparameters will be suggested based on previous results. See [optuna.trial.Trial](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html) API for a full list of functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def objective(trial):\n",
    "    params = {\n",
    "        \"objective\": \"reg:squarederror\",\n",
    "        \"verbosity\": 0,\n",
    "        \"learning_rate\": trial.suggest_float(\"learning_rate\", 1e-3, 0.1, log=True),\n",
    "        \"max_depth\": trial.suggest_int(\"max_depth\", 1, 10),\n",
    "        \"subsample\": trial.suggest_float(\"subsample\", 0.05, 1.0),\n",
    "        \"colsample_bytree\": trial.suggest_float(\"colsample_bytree\", 0.05, 1.0),\n",
    "        \"min_child_weight\": trial.suggest_int(\"min_child_weight\", 1, 20),\n",
    "        \"tree_method\": \"gpu_hist\",\n",
    "        \"device\": \"cuda\",\n",
    "    }\n",
    "\n",
    "    booster = xgb.train(params=params, dtrain=Xy_train_qdm, num_boost_round=trial.suggest_int(\"num_boost_round\", 100, 500))\n",
    "    predictions = booster.inplace_predict(X_val)\n",
    "    rmse = mean_squared_error(y_val, predictions, squared=False).get()\n",
    "    \n",
    "    return rmse   "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the study and optimize. By default, the study results will be stored in memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[I 2024-12-11 23:42:09,341] A new study created in memory with name: optuna-xgboost-local\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[I 2024-12-11 23:42:09,715] Trial 0 finished with value: 0.6377619522504244 and parameters: {'learning_rate': 0.005611516415334507, 'max_depth': 10, 'subsample': 0.7453942447208348, 'colsample_bytree': 0.6187255599871848, 'min_child_weight': 4, 'num_boost_round': 162}. Best is trial 0 with value: 0.6377619522504244.\n",
      "[I 2024-12-11 23:42:10,666] Trial 1 finished with value: 0.6703788974319568 and parameters: {'learning_rate': 0.0013066739238053278, 'max_depth': 9, 'subsample': 0.6210592611560484, 'colsample_bytree': 0.7226689489062432, 'min_child_weight': 1, 'num_boost_round': 488}. Best is trial 0 with value: 0.6377619522504244.\n",
      "[I 2024-12-11 23:42:10,806] Trial 2 finished with value: 0.6181751362616256 and parameters: {'learning_rate': 0.04622589001020832, 'max_depth': 3, 'subsample': 0.2227337188467456, 'colsample_bytree': 0.22423428436076215, 'min_child_weight': 7, 'num_boost_round': 310}. Best is trial 2 with value: 0.6181751362616256.\n",
      "[I 2024-12-11 23:42:10,922] Trial 3 finished with value: 0.6698576232920956 and parameters: {'learning_rate': 0.007309539835912915, 'max_depth': 3, 'subsample': 0.6312602499862605, 'colsample_bytree': 0.18251916761943976, 'min_child_weight': 6, 'num_boost_round': 246}. Best is trial 2 with value: 0.6181751362616256.\n",
      "[I 2024-12-11 23:42:11,039] Trial 4 finished with value: 0.6704590546150145 and parameters: {'learning_rate': 0.008168455894760165, 'max_depth': 8, 'subsample': 0.23969009305044175, 'colsample_bytree': 0.538522716492931, 'min_child_weight': 12, 'num_boost_round': 118}. Best is trial 2 with value: 0.6181751362616256.\n",
      "[I 2024-12-11 23:42:11,191] Trial 5 finished with value: 0.6088806682631155 and parameters: {'learning_rate': 0.016409286730647923, 'max_depth': 2, 'subsample': 0.11179901333601554, 'colsample_bytree': 0.9514412603906666, 'min_child_weight': 20, 'num_boost_round': 424}. Best is trial 5 with value: 0.6088806682631155.\n",
      "[I 2024-12-11 23:42:11,266] Trial 6 finished with value: 0.7103495949713845 and parameters: {'learning_rate': 0.0040665633135147945, 'max_depth': 1, 'subsample': 0.700021375186549, 'colsample_bytree': 0.4681448690526212, 'min_child_weight': 3, 'num_boost_round': 298}. Best is trial 5 with value: 0.6088806682631155.\n",
      "[I 2024-12-11 23:42:11,666] Trial 7 finished with value: 0.7255199474722185 and parameters: {'learning_rate': 0.001171593739230706, 'max_depth': 10, 'subsample': 0.29584098252001606, 'colsample_bytree': 0.6793961701362828, 'min_child_weight': 7, 'num_boost_round': 308}. Best is trial 5 with value: 0.6088806682631155.\n",
      "[I 2024-12-11 23:42:11,829] Trial 8 finished with value: 0.6060010014477214 and parameters: {'learning_rate': 0.0123999678368461, 'max_depth': 2, 'subsample': 0.9711053963763306, 'colsample_bytree': 0.7863761821930588, 'min_child_weight': 19, 'num_boost_round': 458}. Best is trial 8 with value: 0.6060010014477214.\n",
      "[I 2024-12-11 23:42:12,168] Trial 9 finished with value: 0.6292433375858283 and parameters: {'learning_rate': 0.015696396388661146, 'max_depth': 10, 'subsample': 0.13406787694932354, 'colsample_bytree': 0.23618371929818793, 'min_child_weight': 1, 'num_boost_round': 230}. Best is trial 8 with value: 0.6060010014477214.\n"
     ]
    }
   ],
   "source": [
    "study = optuna.create_study(study_name=\"optuna-xgboost-local\", sampler=TPESampler(seed=42))\n",
    "study.optimize(objective, n_trials=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best RMSE:  0.6060010014477214\n",
      "Best hyperparameters:  {'learning_rate': 0.0123999678368461, 'max_depth': 2, 'subsample': 0.9711053963763306, 'colsample_bytree': 0.7863761821930588, 'min_child_weight': 19, 'num_boost_round': 458}\n"
     ]
    }
   ],
   "source": [
    "trial = study.best_trial\n",
    "print(\"Best RMSE: \", trial.value)\n",
    "print(\"Best hyperparameters: \", trial.params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 2. Distributed Optuna on Spark "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PySpark\n",
    "\n",
    "For standalone users, we need to create the Spark session. For Databricks users, the Spark session will be preconfigured."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "24/12/11 23:42:12 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
      "24/12/11 23:42:12 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
      "Setting default log level to \"WARN\".\n",
      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
      "24/12/11 23:42:13 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
     ]
    }
   ],
   "source": [
    "def initialize_spark():\n",
    "    import socket\n",
    "    hostname = socket.gethostname()\n",
    "    conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
    "\n",
    "    conf = SparkConf()\n",
    "    conf.setMaster(f\"spark://{hostname}:7077\")  # Assuming master is on host and default port. \n",
    "    conf.set(\"spark.task.maxFailures\", \"1\")\n",
    "    conf.set(\"spark.task.resource.gpu.amount\", \"1\")\n",
    "    conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n",
    "    conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
    "    conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
    "    \n",
    "    spark = SparkSession.builder.appName(\"optuna-joblibspark-xgboost\").config(conf=conf).getOrCreate()\n",
    "    return spark\n",
    "\n",
    "if 'spark' not in globals():\n",
    "    spark = initialize_spark()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optuna Task\n",
    "\n",
    "This implementation demonstrates **Worker I/O**. \n",
    "\n",
    "This means that each worker will read the full dataset from the filepath rather than passing the data in a dataframe.  \n",
    "In practice, this requires the dataset to be written to a distributed file system accessible to all workers prior to tuning. \n",
    "\n",
    "For the alternative implementation using **Spark I/O**, see the [Spark Dataframe notebook](optuna-dataframe.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the task, each worker will:\n",
    "1. Read the dataset from the filepath\n",
    "2. Load the study from the MySQL storage backend\n",
    "3. Optimize over the objective for the assigned number of trials, sending results back to the database after each iteration\n",
    "\n",
    "Here we use Optuna's [Define-and-Run](https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/009_ask_and_tell.html#define-and-run) API, which allows us to predefine the hyperparameter space and pass it to the task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def task(num_trials: int, xgb_params: dict, optuna_params: dict, driver_ip: str, study_name: str, seed: int, filepath: str):\n",
    "    import cudf\n",
    "    from cuml.metrics.regression import mean_squared_error\n",
    "    from cuml.model_selection import train_test_split\n",
    "\n",
    "    tc = TaskContext.get()\n",
    "    assert \"gpu\" in tc.resources(), \"GPU resource not found.\"\n",
    "\n",
    "    if filepath.startswith(\"/dbfs/\"):\n",
    "        # Check to ensure GPU direct storage is disabled for cuDF on databricks.\n",
    "        libcudf_policy = os.environ.get('LIBCUDF_CUFILE_POLICY')\n",
    "        if libcudf_policy != 'OFF':\n",
    "            raise RuntimeError(\"Set LIBCUDF_CUFILE_POLICY=OFF to read from DBFS with cuDF.\")\n",
    "    \n",
    "    data = cudf.read_csv(filepath, delimiter=\";\")\n",
    "    X = data.iloc[:, :-1].values\n",
    "    y = data[\"quality\"].values\n",
    "    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=seed)\n",
    "\n",
    "    tuning_max_bin = \"max_bin\" in optuna_params\n",
    "    if not tuning_max_bin:\n",
    "        max_bin = xgb_params.get(\"max_bin\", 256)\n",
    "        # Precompute Quantile DMatrix to avoid repeated quantization every trial.\n",
    "        Xy_train_qdm = xgb.QuantileDMatrix(X_train, y_train, max_bin=max_bin)\n",
    "\n",
    "    study = optuna.load_study(\n",
    "        study_name=study_name,\n",
    "        storage=f\"mysql://optuna_user:optuna_password@{driver_ip}/optuna\",\n",
    "        sampler=TPESampler(seed=seed),\n",
    "    )\n",
    "\n",
    "    print(f\"Running {num_trials} trials on partition {tc.partitionId()}.\")\n",
    "\n",
    "    ### Objective ###\n",
    "    for _ in range(num_trials):\n",
    "        trial = study.ask(optuna_params)\n",
    "        xgb_params.update(trial.params)\n",
    "\n",
    "        if tuning_max_bin:\n",
    "            # If tuning the max_bin param, we must recompute the QDM every trial.\n",
    "            if \"n_estimators\" not in xgb_params:\n",
    "                xgb_params[\"n_estimators\"] = 100  # Default value if not tuning.\n",
    "\n",
    "            model = xgb.XGBRegressor(**xgb_params)\n",
    "            model.fit(X_train, y_train)\n",
    "            booster = model.get_booster()\n",
    "        else:\n",
    "            # Train the model with xgb.train() API using the precomputed QDM.\n",
    "            num_boost_round = xgb_params.get(\"n_estimators\", 100)\n",
    "            booster = xgb.train(params=xgb_params, dtrain=Xy_train_qdm, num_boost_round=num_boost_round)\n",
    "            \n",
    "        # Perform in-place predictions on GPU using the booster.\n",
    "        predictions = booster.inplace_predict(X_val)\n",
    "        rmse = mean_squared_error(y_val, predictions, squared=False).get()\n",
    "        \n",
    "        study.tell(trial, rmse)\n",
    "\n",
    "    return study.best_params, study.best_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This will register the Spark Session with the Joblib Spark backend.\n",
    "register_spark()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup and run the Optuna study"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the driver IP for the MySQL database.  \n",
    "- For standalone users, make sure you've followed the [database setup instructions](../README.md#setup-database-for-optuna). The database should be on 'localhost'. \n",
    "- For databricks users, the database should already be setup on the driver node by the init script."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# check if we're running on databricks\n",
    "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MySQL database is hosted on localhost\n"
     ]
    }
   ],
   "source": [
    "if on_databricks:\n",
    "    driver_ip = spark.conf.get(\"spark.driver.host\")\n",
    "else:\n",
    "    driver_ip = \"localhost\"\n",
    "\n",
    "print(f\"MySQL database is hosted on {driver_ip}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a new study, referencing the MySQL database as the storage backend."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[I 2024-12-11 23:42:13,928] A new study created in RDB with name: optuna-xgboost-joblibspark\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<optuna.study.study.Study at 0x76ae75922b00>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "study_name = \"optuna-xgboost-joblibspark\"\n",
    "seed = 42\n",
    "\n",
    "try:\n",
    "    # Delete the study if it already exists\n",
    "    optuna.delete_study(\n",
    "        study_name=study_name, \n",
    "        storage=f\"mysql://optuna_user:optuna_password@{driver_ip}/optuna\"\n",
    "    )\n",
    "except:\n",
    "    pass\n",
    "\n",
    "optuna.create_study(\n",
    "    study_name=study_name,\n",
    "    storage=f\"mysql://optuna_user:optuna_password@{driver_ip}/optuna\",\n",
    "    sampler=TPESampler(seed=seed)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the number of tasks, number of trials, and trials per task. \n",
    "\n",
    "**NOTE**: for standalone users running on a single worker, the 4 tasks will all be assigned to the same worker and executed sequentially in this demonstration. This can easily be scaled up to run concurrently by adding more workers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def partition_trials(total_trials: int, total_tasks: int) -> List[int]:\n",
    "    base_size = total_trials // total_tasks\n",
    "    extra = total_trials % total_tasks\n",
    "    partitions = [base_size] * total_tasks\n",
    "    for i in range(extra):\n",
    "        partitions[i] += 1\n",
    "    \n",
    "    return partitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trials per task: [25, 25, 25, 25]\n"
     ]
    }
   ],
   "source": [
    "num_tasks = 4\n",
    "num_trials = 100\n",
    "trials_per_task = partition_trials(num_trials, num_tasks)\n",
    "print(f\"Trials per task: {trials_per_task}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Define params\n",
    "Define the XGBoost model params and the hyperparams for Optuna to tune. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep these params consistent:\n",
    "xgb_params = {\n",
    "    \"objective\": \"reg:squarederror\",\n",
    "    \"verbosity\": 0,\n",
    "    \"tree_method\": \"gpu_hist\",\n",
    "    \"device\": f\"cuda\",\n",
    "    \"seed\": seed,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tune these params:\n",
    "optuna_params = {\n",
    "    \"n_estimators\": optuna.distributions.IntDistribution(100, 500),\n",
    "    \"learning_rate\": optuna.distributions.FloatDistribution(1e-3, 0.1, log=True),\n",
    "    \"max_depth\": optuna.distributions.IntDistribution(1, 10),\n",
    "    \"subsample\": optuna.distributions.FloatDistribution(0.05, 1.0),\n",
    "    \"colsample_bytree\": optuna.distributions.FloatDistribution(0.05, 1.0),\n",
    "    \"min_child_weight\": optuna.distributions.IntDistribution(1, 20),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**For Databricks**: we must download the dataset to DBFS so that all workers can access it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "if on_databricks:\n",
    "    dbutils.fs.mkdirs(\"/FileStore/optuna-data\")\n",
    "    filepath = \"/dbfs/FileStore/optuna-data/winequality-red.csv\"\n",
    "    url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv\"\n",
    "\n",
    "    response = requests.get(url)\n",
    "    if response.status_code == 200:\n",
    "        with open(filepath, \"wb\") as f:\n",
    "            f.write(response.content)\n",
    "        print(f\"File downloaded and saved to {filepath}\")\n",
    "    else:\n",
    "        print(f\"Failed to download the file. Status code: {response.status_code}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run the study\n",
    "\n",
    "Run parallel threads to execute the Optuna task and collect the reuslts (it might take a few minutes)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/rishic/anaconda3/envs/optuna-spark/lib/python3.10/site-packages/joblibspark/backend.py:115: UserWarning: Spark version does not support stage-level scheduling.\n",
      "  warnings.warn(\"Spark version does not support stage-level scheduling.\")\n",
      "/home/rishic/anaconda3/envs/optuna-spark/lib/python3.10/site-packages/joblibspark/backend.py:154: UserWarning: User-specified n_jobs (4) is greater than the max number of concurrent tasks (1) this cluster can run now.If dynamic allocation is enabled for the cluster, you might see more executors allocated.\n",
      "  warnings.warn(f\"User-specified n_jobs ({n_jobs}) is greater than the max number of \"\n",
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "with joblib.parallel_backend(\"spark\", n_jobs=num_tasks):\n",
    "    results = joblib.Parallel()(\n",
    "        joblib.delayed(task)(num_trials,\n",
    "                             xgb_params,\n",
    "                             optuna_params,\n",
    "                             driver_ip,\n",
    "                             study_name,\n",
    "                             seed,\n",
    "                             filepath) for num_trials in trials_per_task\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best parameters: {'n_estimators': 463, 'learning_rate': 0.05206124631137337, 'max_depth': 9, 'subsample': 0.7434942725744815, 'colsample_bytree': 0.877391644494205, 'min_child_weight': 4}\n",
      "Best value: 0.5324732150787205\n"
     ]
    }
   ],
   "source": [
    "best_params = min(results, key=lambda x: x[1])[0]\n",
    "best_value = min(results, key=lambda x: x[1])[1]\n",
    "\n",
    "print(f\"Best parameters: {best_params}\")\n",
    "print(f\"Best value: {best_value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "optuna-spark",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
